import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn.parameter import Parameter
import numpy as np
from collections import OrderedDict
from torch.nn import Parameter
from networks import deeplab_xception, gcn
import pdb

#######################
# base model
#######################


class deeplab_xception_transfer_basemodel(deeplab_xception.DeepLabv3_plus):
    """DeepLabv3+ (Xception) network extended with a target-graph branch.

    The decoder output is projected onto ``n_classes`` graph nodes, passed
    through three graph convolutions, projected back to a feature map and
    added to a 1x1-conv skip path before the inherited ``semantic`` head runs.
    The backbone, ASPP, decoder and ``semantic`` modules are inherited from
    ``deeplab_xception.DeepLabv3_plus`` (defined outside this file).

    No ``source_*`` modules are created here; the transfer subclasses in this
    file add them.
    """

    def __init__(
        self,
        nInputChannels=3,
        n_classes=7,
        os=16,
        input_channels=256,
        hidden_layers=128,
        out_channels=256,
    ):
        """Build the parent network and the target-graph modules.

        Args:
            nInputChannels: forwarded to ``DeepLabv3_plus``.
            n_classes: forwarded to ``DeepLabv3_plus``; also the number of
                nodes of the target graph.
            os: forwarded to ``DeepLabv3_plus``.
            input_channels: channel count handed to the feature-map/graph
                projections and used by the 1x1 skip convolution.
            hidden_layers: feature size of each graph node.
            out_channels: ``output_channels`` of the graph-to-feature-map
                projection.
                NOTE(review): the result is added to the skip path, so this
                presumably has to equal ``input_channels`` -- confirm in gcn.
        """
        super(deeplab_xception_transfer_basemodel, self).__init__(
            nInputChannels=nInputChannels,
            n_classes=n_classes,
            os=os,
        )

        # Target graph: feature map -> nodes -> 3 graph convs -> feature map.
        self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(
            input_channels=input_channels, hidden_layers=hidden_layers, nodes=n_classes
        )
        self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)

        self.target_graph_2_fea = gcn.Graph_to_Featuremaps(
            input_channels=input_channels,
            output_channels=out_channels,
            hidden_layers=hidden_layers,
            nodes=n_classes,
        )
        # Skip path that is summed with the graph branch output in forward().
        self.target_skip_conv = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1),
            nn.ReLU(True),
        )

    def load_source_model(self, state_dict):
        """Copy matching tensors from a source checkpoint into this model.

        This is a best-effort loader: problems are printed, never raised.

        * ``"module."`` is stripped from every key.
        * Keys containing ``"graph"`` but none of ``"source"``, ``"target"``,
          ``"fc_graph"`` or ``"transpose_graph"`` are renamed so that they
          address the ``source_*`` modules (``featuremap_2_graph`` ->
          ``source_featuremap_2_graph``, otherwise ``graph`` ->
          ``source_graph``).
        * Keys absent from this model are reported as unexpected, except
          those containing ``"num_batch"``, which are skipped silently.
        * Tensors that cannot be copied (e.g. a shape mismatch) are reported
          and skipped.
        * Finally, the model keys that never appeared in the (renamed)
          checkpoint are printed as missing.

        Args:
            state_dict: mapping from parameter name to tensor/``Parameter``.
        """
        own_state = self.state_dict()
        seen_keys = set()
        for name, param in state_dict.items():
            name = name.replace("module.", "")
            if (
                "graph" in name
                and "source" not in name
                and "target" not in name
                and "fc_graph" not in name
                and "transpose_graph" not in name
            ):
                if "featuremap_2_graph" in name:
                    name = name.replace(
                        "featuremap_2_graph", "source_featuremap_2_graph"
                    )
                else:
                    name = name.replace("graph", "source_graph")
            # Recorded before the existence check so that the "missing" report
            # below is computed against the renamed checkpoint keys.
            seen_keys.add(name)
            if name not in own_state:
                if "num_batch" not in name:
                    print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                # Keep loading the remaining tensors; only this one is skipped.
                print(
                    "While copying the parameter named {}, whose dimensions in the model are"
                    " {} and whose dimensions in the checkpoint are {}, ...".format(
                        name, own_state[name].size(), param.size()
                    )
                )

        missing = set(own_state.keys()) - seen_keys
        if missing:
            print('missing keys in state_dict: "{}"'.format(missing))

    def get_target_parameter(self):
        """Split the parameters by name.

        Returns:
            ``(selected, other)``: ``selected`` holds every parameter whose
            name contains ``"target"`` or ``"semantic"``; ``other`` holds the
            rest. Both lists follow ``named_parameters()`` order.
        """
        selected, other = [], []
        for name, param in self.named_parameters():
            if "target" in name or "semantic" in name:
                selected.append(param)
            else:
                other.append(param)
        return selected, other

    def get_semantic_parameter(self):
        """Return the parameters whose name contains ``"semantic"``."""
        return [
            param for name, param in self.named_parameters() if "semantic" in name
        ]

    def get_source_parameter(self):
        """Return the parameters whose name contains ``"source"``."""
        return [param for name, param in self.named_parameters() if "source" in name]

    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        """Run the segmentation network with the target-graph branch.

        Args:
            input: image batch; its last two dimensions define the output
                size (assumed N x C x H x W -- TODO confirm against callers).
            adj1_target: adjacency passed as ``adj`` to the three target graph
                convolutions.
            adj2_source: unused here; kept so all models share one signature.
            adj3_transfer: unused here; kept so all models share one signature.

        Returns:
            Output of the ``semantic`` head, bilinearly resized to the spatial
            size of ``input``.
        """
        x, low_level_features = self.xception_features(input)

        # ASPP: four parallel branches plus pooled features resized to match.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.upsample is deprecated; F.interpolate takes the same arguments.
        x5 = F.interpolate(
            x5, size=x4.size()[2:], mode="bilinear", align_corners=True
        )

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        x = F.interpolate(
            x, size=low_level_features.size()[2:], mode="bilinear", align_corners=True
        )

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        # Target graph reasoning on the decoder features.
        graph = self.target_featuremap_2_graph(x)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # Project the nodes back to a feature map and fuse with the skip path.
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(
            x, size=input.size()[2:], mode="bilinear", align_corners=True
        )

        return x


class deeplab_xception_transfer_basemodel_savememory(deeplab_xception.DeepLabv3_plus):
    """DeepLabv3+ (Xception) network with a target-graph branch ("savemem").

    The structure matches the plain base model in this file except that the
    graph-to-feature-map projection is ``gcn.Graph_to_Featuremaps_savemem``.
    The decoder output is projected onto ``n_classes`` graph nodes, passed
    through three graph convolutions, projected back to a feature map and
    added to a 1x1-conv skip path before the inherited ``semantic`` head runs.
    The backbone, ASPP, decoder and ``semantic`` modules are inherited from
    ``deeplab_xception.DeepLabv3_plus`` (defined outside this file).

    No ``source_*`` modules are created here; the transfer subclass in this
    file adds them.
    """

    def __init__(
        self,
        nInputChannels=3,
        n_classes=7,
        os=16,
        input_channels=256,
        hidden_layers=128,
        out_channels=256,
    ):
        """Build the parent network and the target-graph modules.

        Args:
            nInputChannels: forwarded to ``DeepLabv3_plus``.
            n_classes: forwarded to ``DeepLabv3_plus``; also the number of
                nodes of the target graph.
            os: forwarded to ``DeepLabv3_plus``.
            input_channels: channel count handed to the feature-map/graph
                projections and used by the 1x1 skip convolution.
            hidden_layers: feature size of each graph node.
            out_channels: ``output_channels`` of the graph-to-feature-map
                projection.
                NOTE(review): the result is added to the skip path, so this
                presumably has to equal ``input_channels`` -- confirm in gcn.
        """
        super(deeplab_xception_transfer_basemodel_savememory, self).__init__(
            nInputChannels=nInputChannels,
            n_classes=n_classes,
            os=os,
        )

        # Target graph: feature map -> nodes -> 3 graph convs -> feature map.
        self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(
            input_channels=input_channels, hidden_layers=hidden_layers, nodes=n_classes
        )
        self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)

        self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(
            input_channels=input_channels,
            output_channels=out_channels,
            hidden_layers=hidden_layers,
            nodes=n_classes,
        )
        # Skip path that is summed with the graph branch output in forward().
        self.target_skip_conv = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1),
            nn.ReLU(True),
        )

    def load_source_model(self, state_dict):
        """Copy matching tensors from a source checkpoint into this model.

        This is a best-effort loader: problems are printed, never raised.

        * ``"module."`` is stripped from every key.
        * Keys containing ``"graph"`` but none of ``"source"``, ``"target"``,
          ``"fc_graph"`` or ``"transpose_graph"`` are renamed so that they
          address the ``source_*`` modules (``featuremap_2_graph`` ->
          ``source_featuremap_2_graph``, otherwise ``graph`` ->
          ``source_graph``).
        * Keys absent from this model are reported as unexpected, except
          those containing ``"num_batch"``, which are skipped silently.
        * Tensors that cannot be copied (e.g. a shape mismatch) are reported
          and skipped.
        * Finally, the model keys that never appeared in the (renamed)
          checkpoint are printed as missing.

        Args:
            state_dict: mapping from parameter name to tensor/``Parameter``.
        """
        own_state = self.state_dict()
        seen_keys = set()
        for name, param in state_dict.items():
            name = name.replace("module.", "")
            if (
                "graph" in name
                and "source" not in name
                and "target" not in name
                and "fc_graph" not in name
                and "transpose_graph" not in name
            ):
                if "featuremap_2_graph" in name:
                    name = name.replace(
                        "featuremap_2_graph", "source_featuremap_2_graph"
                    )
                else:
                    name = name.replace("graph", "source_graph")
            # Recorded before the existence check so that the "missing" report
            # below is computed against the renamed checkpoint keys.
            seen_keys.add(name)
            if name not in own_state:
                if "num_batch" not in name:
                    print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                # Keep loading the remaining tensors; only this one is skipped.
                print(
                    "While copying the parameter named {}, whose dimensions in the model are"
                    " {} and whose dimensions in the checkpoint are {}, ...".format(
                        name, own_state[name].size(), param.size()
                    )
                )

        missing = set(own_state.keys()) - seen_keys
        if missing:
            print('missing keys in state_dict: "{}"'.format(missing))

    def get_target_parameter(self):
        """Split the parameters by name.

        Returns:
            ``(selected, other)``: ``selected`` holds every parameter whose
            name contains ``"target"`` or ``"semantic"``; ``other`` holds the
            rest. Both lists follow ``named_parameters()`` order.
        """
        selected, other = [], []
        for name, param in self.named_parameters():
            if "target" in name or "semantic" in name:
                selected.append(param)
            else:
                other.append(param)
        return selected, other

    def get_semantic_parameter(self):
        """Return the parameters whose name contains ``"semantic"``."""
        return [
            param for name, param in self.named_parameters() if "semantic" in name
        ]

    def get_source_parameter(self):
        """Return the parameters whose name contains ``"source"``."""
        return [param for name, param in self.named_parameters() if "source" in name]

    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        """Run the segmentation network with the target-graph branch.

        Args:
            input: image batch; its last two dimensions define the output
                size (assumed N x C x H x W -- TODO confirm against callers).
            adj1_target: adjacency passed as ``adj`` to the three target graph
                convolutions.
            adj2_source: unused here; kept so all models share one signature.
            adj3_transfer: unused here; kept so all models share one signature.

        Returns:
            Output of the ``semantic`` head, bilinearly resized to the spatial
            size of ``input``.
        """
        x, low_level_features = self.xception_features(input)

        # ASPP: four parallel branches plus pooled features resized to match.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.upsample is deprecated; F.interpolate takes the same arguments.
        x5 = F.interpolate(
            x5, size=x4.size()[2:], mode="bilinear", align_corners=True
        )

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        x = F.interpolate(
            x, size=low_level_features.size()[2:], mode="bilinear", align_corners=True
        )

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        # Target graph reasoning on the decoder features.
        graph = self.target_featuremap_2_graph(x)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # Project the nodes back to a feature map and fuse with the skip path.
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(
            x, size=input.size()[2:], mode="bilinear", align_corners=True
        )

        return x


#######################
# transfer model
#######################


class deeplab_xception_transfer_projection(deeplab_xception_transfer_basemodel):
    """Base model plus a source graph whose nodes are transferred to the target.

    On top of the target-graph branch of the base class this adds:

    * a source graph with ``source_classes`` nodes and its graph convolutions;
    * ``transpose_graph``, which maps source nodes to target nodes using a
      transfer adjacency;
    * ``fc_graph``, which fuses the concatenation of three node feature sets
      (``hidden_layers * 3`` wide) back to ``hidden_layers``.

    Source information reaches the target graph two ways at every step: via
    ``transpose_graph`` and via ``similarity_trans``.
    """

    def __init__(
        self,
        nInputChannels=3,
        n_classes=7,
        os=16,
        input_channels=256,
        hidden_layers=128,
        out_channels=256,
        transfer_graph=None,
        source_classes=20,
    ):
        """Build the base model and the source/transfer graph modules.

        Args:
            nInputChannels: forwarded to the base class.
            n_classes: forwarded to the base class; also ``end_nodes`` of the
                transfer graph.
            os: forwarded to the base class.
            input_channels: forwarded to the base class; also the input width
                of the source feature-map-to-graph projection.
            hidden_layers: feature size of each graph node.
            out_channels: forwarded to the base class.
            transfer_graph: passed as ``adj`` to ``gcn.Graph_trans``.
            source_classes: number of source-graph nodes (``begin_nodes`` of
                the transfer graph).
        """
        super(deeplab_xception_transfer_projection, self).__init__(
            nInputChannels=nInputChannels,
            n_classes=n_classes,
            os=os,
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            out_channels=out_channels,
        )
        self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            nodes=source_classes,
        )
        self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(
            in_features=hidden_layers,
            out_features=hidden_layers,
            adj=transfer_graph,
            begin_nodes=source_classes,
            end_nodes=n_classes,
        )
        # Input is the concatenation of three hidden_layers-wide node sets.
        self.fc_graph = gcn.GraphConvolution(hidden_layers * 3, hidden_layers)

    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        """Run the network with source-to-target graph transfer.

        Args:
            input: image batch; its last two dimensions define the output
                size (assumed N x C x H x W -- TODO confirm against callers).
            adj1_target: adjacency passed to the target graph convolutions.
            adj2_source: adjacency passed to the source graph convolutions.
            adj3_transfer: adjacency passed to ``transpose_graph``.

        Returns:
            Output of the ``semantic`` head, bilinearly resized to the spatial
            size of ``input``.
        """
        x, low_level_features = self.xception_features(input)

        # ASPP: four parallel branches plus pooled features resized to match.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.upsample is deprecated; F.interpolate takes the same arguments.
        x5 = F.interpolate(
            x5, size=x4.size()[2:], mode="bilinear", align_corners=True
        )

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        x = F.interpolate(
            x, size=low_level_features.size()[2:], mode="bilinear", align_corners=True
        )

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        # Source graph: three successive graph-convolution steps.
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(
            source_graph, adj=adj2_source, relu=True
        )
        source_graph2 = self.source_graph_conv2.forward(
            source_graph1, adj=adj2_source, relu=True
        )
        # NOTE(review): the third step reuses source_graph_conv2, so
        # source_graph_conv3 (created in __init__) is never applied. This looks
        # like a copy-paste slip, but checkpoints trained with this forward
        # pass depend on it -- confirm before switching to source_graph_conv3.
        source_graph3 = self.source_graph_conv2.forward(
            source_graph2, adj=adj2_source, relu=True
        )

        # Source -> target node transfer through the transfer adjacency.
        source_2_target_graph1_v5 = self.transpose_graph.forward(
            source_graph1, adj=adj3_transfer, relu=True
        )
        source_2_target_graph2_v5 = self.transpose_graph.forward(
            source_graph2, adj=adj3_transfer, relu=True
        )
        source_2_target_graph3_v5 = self.transpose_graph.forward(
            source_graph3, adj=adj3_transfer, relu=True
        )

        # Target graph.
        graph = self.target_featuremap_2_graph(x)

        # Combine 1: target nodes + similarity transfer + adjacency transfer.
        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        # NOTE(review): squeeze(0) is only applied in this first combine step,
        # presumably to drop a leading singleton dimension that the source
        # branch carries but the fresh target graph does not -- confirm in gcn.
        graph = torch.cat(
            (
                graph,
                source_2_target_graph1.squeeze(0),
                source_2_target_graph1_v5.squeeze(0),
            ),
            dim=-1,
        )
        # The same fc_graph module (shared weights) fuses all three steps.
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        # Combine 2.
        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        graph = torch.cat(
            (graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        # Combine 3.
        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        graph = torch.cat(
            (graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # Project the nodes back to a feature map and fuse with the skip path.
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(
            x, size=input.size()[2:], mode="bilinear", align_corners=True
        )

        return x

    def similarity_trans(self, source, target):
        """Re-express source nodes on the target nodes by cosine similarity.

        Both inputs are L2-normalised along the last dimension, so their
        product is a cosine-similarity matrix (target nodes x source nodes).
        A softmax over the source-node axis turns each row into weights, and
        the result is the weighted sum of the (un-normalised) source nodes.

        Args:
            source: source node features; last dimension is the feature axis.
            target: target node features; last dimension is the feature axis.

        Returns:
            Tensor with one row per target node holding its similarity-weighted
            mix of source node features.
        """
        sim = torch.matmul(
            F.normalize(target, p=2, dim=-1),
            F.normalize(source, p=2, dim=-1).transpose(-1, -2),
        )
        sim = F.softmax(sim, dim=-1)
        return torch.matmul(sim, source)

    def load_source_model(self, state_dict):
        """Copy matching tensors from a source checkpoint into this model.

        This is a best-effort loader: problems are printed, never raised.

        * ``"module."`` is stripped from every key.
        * Keys containing ``"graph"`` but none of ``"source"``, ``"target"``,
          ``"fc_"`` or ``"transpose_graph"`` are renamed so that they address
          the ``source_*`` modules (``featuremap_2_graph`` ->
          ``source_featuremap_2_graph``, otherwise ``graph`` ->
          ``source_graph``).
        * Keys absent from this model are reported as unexpected, except
          those containing ``"num_batch"``, which are skipped silently.
        * Tensors that cannot be copied (e.g. a shape mismatch) are reported
          and skipped.
        * Finally, the model keys that never appeared in the (renamed)
          checkpoint are printed as missing.

        Args:
            state_dict: mapping from parameter name to tensor/``Parameter``.
        """
        own_state = self.state_dict()
        seen_keys = set()
        for name, param in state_dict.items():
            name = name.replace("module.", "")
            if (
                "graph" in name
                and "source" not in name
                and "target" not in name
                and "fc_" not in name
                and "transpose_graph" not in name
            ):
                if "featuremap_2_graph" in name:
                    name = name.replace(
                        "featuremap_2_graph", "source_featuremap_2_graph"
                    )
                else:
                    name = name.replace("graph", "source_graph")
            # Recorded before the existence check so that the "missing" report
            # below is computed against the renamed checkpoint keys.
            seen_keys.add(name)
            if name not in own_state:
                if "num_batch" not in name:
                    print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                # Keep loading the remaining tensors; only this one is skipped.
                print(
                    "While copying the parameter named {}, whose dimensions in the model are"
                    " {} and whose dimensions in the checkpoint are {}, ...".format(
                        name, own_state[name].size(), param.size()
                    )
                )

        missing = set(own_state.keys()) - seen_keys
        if missing:
            print('missing keys in state_dict: "{}"'.format(missing))


class deeplab_xception_transfer_projection_savemem(
    deeplab_xception_transfer_basemodel_savememory
):
    """"savemem" base model plus a source graph transferred to the target.

    On top of the target-graph branch of the base class this adds:

    * a source graph with ``source_classes`` nodes and its graph convolutions;
    * ``transpose_graph``, which maps source nodes to target nodes using a
      transfer adjacency;
    * ``fc_graph``, which fuses the concatenation of three node feature sets
      (``hidden_layers * 3`` wide) back to ``hidden_layers``.

    Source information reaches the target graph two ways at every step: via
    ``transpose_graph`` and via ``similarity_trans``.
    """

    def __init__(
        self,
        nInputChannels=3,
        n_classes=7,
        os=16,
        input_channels=256,
        hidden_layers=128,
        out_channels=256,
        transfer_graph=None,
        source_classes=20,
    ):
        """Build the base model and the source/transfer graph modules.

        Args:
            nInputChannels: forwarded to the base class.
            n_classes: forwarded to the base class; also ``end_nodes`` of the
                transfer graph.
            os: forwarded to the base class.
            input_channels: forwarded to the base class; also the input width
                of the source feature-map-to-graph projection.
            hidden_layers: feature size of each graph node.
            out_channels: forwarded to the base class.
            transfer_graph: passed as ``adj`` to ``gcn.Graph_trans``.
            source_classes: number of source-graph nodes (``begin_nodes`` of
                the transfer graph).
        """
        super(deeplab_xception_transfer_projection_savemem, self).__init__(
            nInputChannels=nInputChannels,
            n_classes=n_classes,
            os=os,
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            out_channels=out_channels,
        )
        self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            nodes=source_classes,
        )
        self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(
            in_features=hidden_layers,
            out_features=hidden_layers,
            adj=transfer_graph,
            begin_nodes=source_classes,
            end_nodes=n_classes,
        )
        # Input is the concatenation of three hidden_layers-wide node sets.
        self.fc_graph = gcn.GraphConvolution(hidden_layers * 3, hidden_layers)

    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        """Run the network with source-to-target graph transfer.

        Args:
            input: image batch; its last two dimensions define the output
                size (assumed N x C x H x W -- TODO confirm against callers).
            adj1_target: adjacency passed to the target graph convolutions.
            adj2_source: adjacency passed to the source graph convolutions.
            adj3_transfer: adjacency passed to ``transpose_graph``.

        Returns:
            Output of the ``semantic`` head, bilinearly resized to the spatial
            size of ``input``.
        """
        x, low_level_features = self.xception_features(input)

        # ASPP: four parallel branches plus pooled features resized to match.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.upsample is deprecated; F.interpolate takes the same arguments.
        x5 = F.interpolate(
            x5, size=x4.size()[2:], mode="bilinear", align_corners=True
        )

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        x = F.interpolate(
            x, size=low_level_features.size()[2:], mode="bilinear", align_corners=True
        )

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        # Source graph: three successive graph-convolution steps.
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(
            source_graph, adj=adj2_source, relu=True
        )
        source_graph2 = self.source_graph_conv2.forward(
            source_graph1, adj=adj2_source, relu=True
        )
        # NOTE(review): the third step reuses source_graph_conv2, so
        # source_graph_conv3 (created in __init__) is never applied. This looks
        # like a copy-paste slip, but checkpoints trained with this forward
        # pass depend on it -- confirm before switching to source_graph_conv3.
        source_graph3 = self.source_graph_conv2.forward(
            source_graph2, adj=adj2_source, relu=True
        )

        # Source -> target node transfer through the transfer adjacency.
        source_2_target_graph1_v5 = self.transpose_graph.forward(
            source_graph1, adj=adj3_transfer, relu=True
        )
        source_2_target_graph2_v5 = self.transpose_graph.forward(
            source_graph2, adj=adj3_transfer, relu=True
        )
        source_2_target_graph3_v5 = self.transpose_graph.forward(
            source_graph3, adj=adj3_transfer, relu=True
        )

        # Target graph.
        graph = self.target_featuremap_2_graph(x)

        # Combine 1: target nodes + similarity transfer + adjacency transfer.
        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        # NOTE(review): squeeze(0) is only applied in this first combine step,
        # presumably to drop a leading singleton dimension that the source
        # branch carries but the fresh target graph does not -- confirm in gcn.
        graph = torch.cat(
            (
                graph,
                source_2_target_graph1.squeeze(0),
                source_2_target_graph1_v5.squeeze(0),
            ),
            dim=-1,
        )
        # The same fc_graph module (shared weights) fuses all three steps.
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        # Combine 2.
        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        graph = torch.cat(
            (graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        # Combine 3.
        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        graph = torch.cat(
            (graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # Project the nodes back to a feature map and fuse with the skip path.
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(
            x, size=input.size()[2:], mode="bilinear", align_corners=True
        )

        return x

    def similarity_trans(self, source, target):
        """Re-express source nodes on the target nodes by cosine similarity.

        Both inputs are L2-normalised along the last dimension, so their
        product is a cosine-similarity matrix (target nodes x source nodes).
        A softmax over the source-node axis turns each row into weights, and
        the result is the weighted sum of the (un-normalised) source nodes.

        Args:
            source: source node features; last dimension is the feature axis.
            target: target node features; last dimension is the feature axis.

        Returns:
            Tensor with one row per target node holding its similarity-weighted
            mix of source node features.
        """
        sim = torch.matmul(
            F.normalize(target, p=2, dim=-1),
            F.normalize(source, p=2, dim=-1).transpose(-1, -2),
        )
        sim = F.softmax(sim, dim=-1)
        return torch.matmul(sim, source)

    def load_source_model(self, state_dict):
        """Copy matching tensors from a source checkpoint into this model.

        This is a best-effort loader: problems are printed, never raised.

        * ``"module."`` is stripped from every key.
        * Keys containing ``"graph"`` but none of ``"source"``, ``"target"``,
          ``"fc_"`` or ``"transpose_graph"`` are renamed so that they address
          the ``source_*`` modules (``featuremap_2_graph`` ->
          ``source_featuremap_2_graph``, otherwise ``graph`` ->
          ``source_graph``).
        * Keys absent from this model are reported as unexpected, except
          those containing ``"num_batch"``, which are skipped silently.
        * Tensors that cannot be copied (e.g. a shape mismatch) are reported
          and skipped.
        * Finally, the model keys that never appeared in the (renamed)
          checkpoint are printed as missing.

        Args:
            state_dict: mapping from parameter name to tensor/``Parameter``.
        """
        own_state = self.state_dict()
        seen_keys = set()
        for name, param in state_dict.items():
            name = name.replace("module.", "")
            if (
                "graph" in name
                and "source" not in name
                and "target" not in name
                and "fc_" not in name
                and "transpose_graph" not in name
            ):
                if "featuremap_2_graph" in name:
                    name = name.replace(
                        "featuremap_2_graph", "source_featuremap_2_graph"
                    )
                else:
                    name = name.replace("graph", "source_graph")
            # Recorded before the existence check so that the "missing" report
            # below is computed against the renamed checkpoint keys.
            seen_keys.add(name)
            if name not in own_state:
                if "num_batch" not in name:
                    print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                # Keep loading the remaining tensors; only this one is skipped.
                print(
                    "While copying the parameter named {}, whose dimensions in the model are"
                    " {} and whose dimensions in the checkpoint are {}, ...".format(
                        name, own_state[name].size(), param.size()
                    )
                )

        missing = set(own_state.keys()) - seen_keys
        if missing:
            print('missing keys in state_dict: "{}"'.format(missing))


# if __name__ == '__main__':
# net = deeplab_xception_transfer_projection_v3v5_more_savemem()
# img = torch.rand((2,3,128,128))
# net.eval()
# a = torch.rand((1,1,7,7))
# net.forward(img, adj1_target=a)
